import functools

import jax.numpy as jnp
from jax import jit

from python_basic.study.python_cookbook.ch9_metaprog.x91_timethis import timethis


def slow_f(x):
    """Return ``x ** 5 + x ** 4``.

    No ``jit`` is applied here; this is the un-compiled baseline. For an
    array argument the result follows the operand type's ``**`` and ``+``
    semantics (element-wise for JAX/NumPy arrays).
    """
    quintic = x ** 5
    quartic = x ** 4
    return quintic + quartic


fast_f = jit(slow_f)


def _blocking(func):
    """Wrap *func* so each call waits until its JAX result is computed.

    JAX dispatches work asynchronously: a call returns once the computation
    is enqueued, not when it has finished. A wall-clock timer around an
    unblocked call therefore measures dispatch overhead only. The wrapper
    returns the same (now materialized) array that *func* returned.
    ``functools.wraps`` copies the wrapped function's ``__name__``,
    ``__module__`` and related metadata, so a timing decorator that reports
    them still shows the original function.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs).block_until_ready()

    return wrapper


timed_slow_f = timethis(_blocking(slow_f))
timed_fast_f = timethis(_blocking(fast_f))

x = jnp.ones((5000, 5000))

# Warm up both paths before timing so one-off first-call costs are excluded:
# fast_f is traced and XLA-compiled on its first call, and eager JAX ops
# also incur first-use overhead. Block so the warm-up fully completes.
fast_f(x).block_until_ready()
slow_f(x).block_until_ready()

timed_fast_f(x)
timed_slow_f(x)
